{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Shows how one can generate text given a prompt and some hyperparameters, using either minGPT or huggingface/transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
    "from mingpt.model import GPT\n",
    "from mingpt.utils import set_seed\n",
    "from mingpt.bpe import BPETokenizer\n",
    "set_seed(3407)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_mingpt = True # use minGPT or huggingface/transformers model?\n",
    "model_type = 'gpt2-xl'\n",
    "# prefer the GPU, but fall back to CPU so the notebook still runs on machines without CUDA\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of parameters: 1557.61M\n"
     ]
    }
   ],
   "source": [
    "if not use_mingpt:\n",
    "    # huggingface/transformers implementation\n",
    "    model = GPT2LMHeadModel.from_pretrained(model_type)\n",
    "    model.config.pad_token_id = model.config.eos_token_id # suppress a warning\n",
    "else:\n",
    "    # minGPT implementation\n",
    "    model = GPT.from_pretrained(model_type)\n",
    "\n",
    "# move the model onto `device`, then switch it to eval mode\n",
    "model.to(device)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def generate(prompt='', num_samples=10, steps=20, do_sample=True, top_k=40):\n",
    "    \"\"\"Sample `num_samples` continuations of `prompt` from the global `model` and print them.\n",
    "\n",
    "    Args:\n",
    "        prompt: text to condition on; '' starts from a lone <|endoftext|> token instead\n",
    "        num_samples: how many continuations to generate, processed together as one batch\n",
    "        steps: forwarded to `model.generate` as `max_new_tokens`\n",
    "        do_sample: forwarded to `model.generate`\n",
    "        top_k: forwarded to `model.generate`\n",
    "\n",
    "    Reads the notebook-level globals `use_mingpt`, `model_type`, `model` and `device`.\n",
    "    \"\"\"\n",
    "\n",
    "    # tokenize the input prompt into integer input sequence\n",
    "    if use_mingpt:\n",
    "        tokenizer = BPETokenizer()\n",
    "        if prompt == '':\n",
    "            # to create unconditional samples...\n",
    "            # manually create a tensor with only the special <|endoftext|> token\n",
    "            # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py\n",
    "            # build it directly on `device`, like every other branch, so it matches where the model lives\n",
    "            x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long, device=device)\n",
    "        else:\n",
    "            x = tokenizer(prompt).to(device)\n",
    "    else:\n",
    "        tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n",
    "        if prompt == '': \n",
    "            # to create unconditional samples...\n",
    "            # huggingface/transformers tokenizer special cases these strings\n",
    "            prompt = '<|endoftext|>'\n",
    "        encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n",
    "        x = encoded_input['input_ids']\n",
    "    \n",
    "    # we'll process all desired num_samples in a batch, so expand out the batch dim\n",
    "    x = x.expand(num_samples, -1)\n",
    "\n",
    "    # forward the model `steps` times to get samples, in a batch\n",
    "    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=top_k)\n",
    "    \n",
    "    # decode each row of the batch back to text and print it under a separator line\n",
    "    for i in range(num_samples):\n",
    "        out = tokenizer.decode(y[i].cpu().squeeze())\n",
    "        print('-'*80)\n",
    "        print(out)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n",
      "--------------------------------------------------------------------------------\n",
      "Andrej Karpathy, the state's environmental protection commissioner. \"That's why we have to keep these systems in place.\"\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# draw a batch of continuations for a fixed prompt\n",
    "generate(\n",
    "    prompt='Andrej Karpathy, the',\n",
    "    num_samples=10,\n",
    "    steps=20,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
